"""
Phase 9: Round 2 Referee Revisions
Development Threshold Paper

1. Working-age share / youth dependency in within-zone savings regressions
2. Fine-Gray competing risk model (cause-specific Cox + multinomial logit)
3. Alternative threshold crossing counts reconciliation
"""

import pandas as pd
import numpy as np
from scipy import stats
import statsmodels.api as sm
from statsmodels.duration.hazard_regression import PHReg
import sys
sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')
from model import PanelGLS

# ═══════════════════════════════════════════════════════════════════════
# Load data
# ═══════════════════════════════════════════════════════════════════════
full = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
# Keep observations through 2024 only
panel = full[full['year'] <= 2024].copy()

# Resource rents are a best-effort add-on: if the file is missing or does not
# have the expected columns, carry on with an all-NaN column, but say why
# instead of failing silently.
try:
    rents = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/raw/wdi_resource_rents.csv')
    panel = panel.merge(rents[['iso3', 'year', 'resource_rents_gdp']], on=['iso3', 'year'], how='left')
except Exception as exc:
    print(f"WARNING: could not attach resource rents ({type(exc).__name__}: {exc}); "
          "resource_rents_gdp set to NaN")
    panel['resource_rents_gdp'] = np.nan

cross_df = pd.read_csv('/mnt/c/demographics_capital_flows/development_threshold/data/crossing_data.csv')
# Bounds of the development zone, compared against gdp_pc_ppp below
LOWER = 9000
UPPER = 25000

# Zone entrants: countries with a recorded lower-bound crossing year, excluding
# those whose status is 'Always above'
zone_entrants = cross_df[cross_df['cross_9k'].notna() & (cross_df['status'] != 'Always above')].copy()

print(f"Zone entrants: {len(zone_entrants)}")
print(f"  Crossed: {zone_entrants['exited_above'].sum()}")
print(f"  Did not cross: {(zone_entrants['exited_above'] == 0).sum()}")

# Collection for output table: one dict per reported result row
round2_results = []

# ═══════════════════════════════════════════════════════════════════════
# 1. WORKING-AGE SHARE / YOUTH DEPENDENCY IN WITHIN-ZONE SAVINGS
# ═══════════════════════════════════════════════════════════════════════
for banner_line in (
    "\n" + "=" * 70,
    "1. WITHIN-ZONE SAVINGS REGRESSIONS: ALTERNATIVE DEMOGRAPHIC MEASURES",
    "=" * 70,
    "Referee concern: Z₁ → savings = -2.11 within the zone has awkward sign",
    "(more mature = less savings). Using working_age_share or youth_dep directly",
    "may clarify the mechanism.\n",
):
    print(banner_line)

# Keep only country-years whose GDP per capita lies inside [LOWER, UPPER]
gdp_level = panel['gdp_pc_ppp']
inside_zone = (gdp_level >= LOWER) & (gdp_level <= UPPER)
zone_panel = panel[inside_zone].copy()
n_zone_countries = zone_panel['iso3'].nunique()
print(f"Within-zone panel: {len(zone_panel)} country-years, {n_zone_countries} countries")

# Report how many non-missing observations each regression variable has
# (0 when the column is absent from the panel altogether)
candidate_vars = [
    'Z_1', 'working_age_share', 'youth_dep', 'old_dep',
    'gross_savings_gdp', 'ca_gdp', 'fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth',
]
for var_name in candidate_vars:
    if var_name in zone_panel.columns:
        non_null = zone_panel[var_name].notna().sum()
    else:
        non_null = 0
    print(f"  {var_name}: {non_null} non-null")

# Control variables shared by every within-zone regression
controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']

def run_panel_gls(df, y_var, demo_var, controls, label):
    """Fit a PanelGLS regression of ``y_var`` on ``demo_var`` plus ``controls``.

    Rows with a missing value in any used column (including ``iso3`` and
    ``year``) are dropped first. Prints one line per regressor and returns
    a dict with the sample size, R² and the demographic regressor's
    coefficient, standard error and p-value. When fewer than 30 complete
    rows remain, prints a notice and returns None without fitting.
    """

    def _stars(p):
        # Conventional significance markers at the 1% / 5% / 10% levels
        for cutoff, mark in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
            if p < cutoff:
                return mark
        return ''

    regressors = [demo_var] + controls
    complete = df[[y_var] + regressors + ['iso3', 'year']].dropna().copy()
    n_complete = len(complete)
    if n_complete < 30:
        print(f"  {label}: insufficient obs ({n_complete})")
        return None

    gls = PanelGLS()
    gls.fit(
        complete[y_var].values,
        complete[regressors].values,
        complete['iso3'].values,
        complete['year'].values,
    )

    print(f"\n  {label} (n={gls.n_obs}, R²={gls.r_squared:.4f}):")
    result = {
        'analysis': 'Within-zone savings',
        'specification': label,
        'n': gls.n_obs,
        'r_squared': gls.r_squared,
    }
    for idx, name in enumerate(regressors):
        coef = gls.beta[idx]
        std_err = gls.se[idx]
        pval = gls.pvalues[idx]
        print(f"    {name:25s}  beta={coef:8.4f}  se={std_err:.4f}  p={pval:.4f}{_stars(pval)}")
        if name == demo_var:
            # Only the demographic regressor is carried into the output table
            result.update(demo_var=name, demo_coef=coef, demo_se=std_err, demo_pval=pval)

    return result

# 1a-1f. Within-zone specifications as (outcome, demographic regressor, label).
# The first row is the baseline Z₁ → savings regression (reproduce -2.11).
within_zone_specs = [
    ('gross_savings_gdp', 'Z_1', 'Z₁ → Savings (within zone)'),
    ('gross_savings_gdp', 'working_age_share', 'Working-age share → Savings'),
    ('gross_savings_gdp', 'youth_dep', 'Youth dependency → Savings'),
    ('gross_savings_gdp', 'old_dep', 'Old-age dependency → Savings'),
    ('ca_gdp', 'Z_1', 'Z₁ → Current Account (within zone)'),
    ('ca_gdp', 'working_age_share', 'Working-age share → Current Account'),
]
r1a, r1b, r1c, r1d, r1e, r1f = [
    run_panel_gls(zone_panel, outcome, demo, controls, spec_label)
    for outcome, demo, spec_label in within_zone_specs
]

# Collect results, skipping specifications that could not be estimated
round2_results.extend(
    res for res in (r1a, r1b, r1c, r1d, r1e, r1f) if res is not None
)

# Interpretation: contrast the sign on Z₁ with the sign on working-age share
print("\n  INTERPRETATION:")
if r1a and r1b:
    z1_coef = r1a['demo_coef']
    was_coef = r1b['demo_coef']
    z1_sign = "negative" if z1_coef < 0 else "positive"
    was_sign = "positive" if was_coef > 0 else "negative"
    print(f"  Z₁ → Savings: {z1_sign} ({z1_coef:.4f})")
    print(f"  Working-age share → Savings: {was_sign} ({was_coef:.4f})")
    if z1_coef < 0 and was_coef > 0:
        for note_line in (
            "  CONFIRMED: Z₁ captures aging (incl. old-age dependency drag on savings)",
            "  while working-age share captures the demographic dividend itself.",
            "  Within the $9k-$25k zone, higher Z₁ = more mature = approaching the point",
            "  where old-age dependency begins to drag savings down, even as working-age",
            "  share (the dividend) still boosts savings.",
        ):
            print(note_line)


# ═══════════════════════════════════════════════════════════════════════
# 2. FINE-GRAY COMPETING RISK MODEL
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2. COMPETING RISK MODELS (CAUSE-SPECIFIC COX + MULTINOMIAL LOGIT)")
print("=" * 70)
print("Three outcomes: (1) crossed above $25k, (2) fell back below $9k, (0) censored")

# Build survival data with competing events.
# Fell-back countries are identified from the panel itself further below; the
# phase7 fell-back table (table17) is an optional, broader supplement.
try:
    fb_table = pd.read_csv('/mnt/c/demographics_capital_flows/development_threshold/output/tables/table17_fell_back.csv')
    # Missing iso3 codes can never match a country and would break sorted()
    fb_isos = set(fb_table['iso3'].dropna().tolist())
except Exception as exc:
    # Best effort: continue without the supplementary list, but report why.
    print(f"WARNING: fell-back table unavailable ({type(exc).__name__}: {exc}); "
          "using within-panel fallbacks only")
    # Keep fb_table defined so later lookups cannot hit an undefined name
    fb_table = pd.DataFrame(columns=['iso3', 'trough_year'])
    fb_isos = set()
else:
    print(f"Fell-back countries (from phase7 table17): {len(fb_isos)}")
    print(f"  {sorted(fb_isos)}")

# Build event data for all zone entrants
# One survival record per zone entrant: time from zone entry to the first
# recorded outcome, with event codes 0=censored, 1=crossed, 2=fell back.
surv_records = []
for _, row in zone_entrants.iterrows():
    iso = row['iso3']
    entry_year = row['cross_9k']
    # Defensive: zone_entrants is already filtered to non-null cross_9k
    if pd.isna(entry_year):
        continue
    entry_year = int(entry_year)

    # Panel rows for this country from the entry year onwards, in time order
    c = panel[(panel['iso3'] == iso) & (panel['year'] >= entry_year)].sort_values('year')
    if len(c) == 0:
        continue

    # Determine outcome
    cross_25k = row['cross_25k']
    crossed = row['exited_above'] == 1

    # Check if fell back below $9k at any point after entry.
    # NOTE(review): the check is skipped for countries flagged as crossed, so a
    # dip below LOWER that precedes an eventual crossing is not recorded as a
    # competing event - confirm this is the intended definition.
    fell_back = False
    fallback_year = np.nan
    if not crossed:
        below_9k = c[c['gdp_pc_ppp'] < LOWER]
        if len(below_9k) > 0:
            fallback_year = below_9k['year'].min()
            fell_back = True
    # Also check the broader fell-back list (countries that had major declines)
    if iso in fb_isos and not crossed:
        fell_back = True
        if pd.isna(fallback_year):
            # Use trough year from table17 if available
            fb_row = fb_table[fb_table['iso3'] == iso]
            if len(fb_row) > 0:
                trough_yr = fb_row['trough_year'].values[0]
                fallback_year = trough_yr

    # Determine duration and event.
    # A crossing takes precedence; a fell-back flag without a usable year
    # falls through to the censored branch.
    if crossed and pd.notna(cross_25k):
        duration = int(cross_25k) - entry_year
        event = 1  # crossed
    elif fell_back and pd.notna(fallback_year):
        duration = int(fallback_year) - entry_year
        event = 2  # fell back
    else:
        # Censored: still in zone or last observation
        last_year = c['year'].max()
        duration = last_year - entry_year
        event = 0  # censored

    # NOTE(review): this also rewrites negative durations (e.g. a table17
    # trough year earlier than the entry year) to 1 - confirm that is wanted.
    if duration <= 0:
        duration = 1  # minimum 1 year

    # Entry covariates come from the crossing-data row; NaN when a column is absent
    surv_records.append({
        'iso3': iso,
        'entry_year': entry_year,
        'duration': duration,
        'event': event,  # 0=censored, 1=crossed, 2=fell_back
        'Z_1_entry': row.get('Z_1_entry', np.nan),
        'gdp_pc_ppp_entry': row.get('gdp_pc_ppp_entry', np.nan),
        'kaopen_entry': row.get('kaopen_entry', np.nan),
        'resource_rents_gdp_entry': row.get('resource_rents_gdp_entry', np.nan),
        'working_age_share_entry': row.get('working_age_share_entry', np.nan),
    })

surv = pd.DataFrame(surv_records)
# Every model below uses Z_1_entry, so drop entrants without it
surv = surv[surv['Z_1_entry'].notna()].copy()

# Sample composition by event code
event_code = surv['event']
print(f"\nSurvival sample: n={len(surv)}")
for caption, code in (("Crossed", 1), ("Fell back", 2), ("Censored", 0)):
    print(f"  {caption} (event={code}): {(event_code == code).sum()}")

# List each fell-back country with its entry year, duration and Z₁ at entry
print(f"\nFell-back countries in survival data:")
fell_back_rows = surv[event_code == 2]
for iso_code, entered, years_in_zone, z1_at_entry in zip(
    fell_back_rows['iso3'],
    fell_back_rows['entry_year'],
    fell_back_rows['duration'],
    fell_back_rows['Z_1_entry'],
):
    print(f"  {iso_code}: entry={int(entered)}, duration={int(years_in_zone)}yr, "
          f"Z₁={z1_at_entry:.3f}")

# ─── 2a. Cause-specific Cox for CROSSING (censor fell-back + still-in-zone) ───
print("\n  2a. CAUSE-SPECIFIC COX: Crossing (fell-back censored)")
# Only a crossing counts as the event; every other outcome is treated as censored
surv_cross = surv.copy()
surv_cross['status'] = (surv_cross['event'] == 1).astype(int)

has_z1 = surv_cross['Z_1_entry'].notna()
positive_duration = surv_cross['duration'] > 0
s = surv_cross[has_z1 & positive_duration]
n_crossings = s['status'].sum()
n_censored = (s['status'] == 0).sum()
print(f"  Sample: n={len(s)}, events={n_crossings}, censored={n_censored}")

try:
    cox_cross = PHReg(
        s['duration'].values,
        s[['Z_1_entry']].values,
        status=s['status'].values,
    )
    cox_cross_result = cox_cross.fit()
    log_hr_cross = cox_cross_result.params[0]
    hr_cross = np.exp(log_hr_cross)
    p_cross = cox_cross_result.pvalues[0]
    # Significance markers at the 1% / 5% / 10% levels
    if p_cross < 0.01:
        sig = '***'
    elif p_cross < 0.05:
        sig = '**'
    elif p_cross < 0.1:
        sig = '*'
    else:
        sig = ''
    print(f"  Z₁ log-HR: {log_hr_cross:.4f}, HR={hr_cross:.3f}, p={p_cross:.4f}{sig}")
    round2_results.append({
        'analysis': 'Competing risk: cause-specific Cox',
        'specification': 'Crossing (fell-back censored)',
        'demo_var': 'Z_1_entry',
        'demo_coef': log_hr_cross,
        'demo_se': cox_cross_result.bse[0],
        'demo_pval': p_cross,
        'n': len(s),
        'note': f'HR={hr_cross:.3f}',
    })
except Exception as err:
    # Report the failure and keep going with the remaining analyses
    print(f"  Failed: {err}")
    import traceback
    traceback.print_exc()

# ─── 2b. Cause-specific Cox for FALLING BACK (censor crossed + still-in-zone) ───
print("\n  2b. CAUSE-SPECIFIC COX: Falling back (crossed censored)")
# Only falling back counts as the event; crossings and still-in-zone are censored
surv_fb = surv.copy()
surv_fb['status'] = (surv_fb['event'] == 2).astype(int)

fb_has_z1 = surv_fb['Z_1_entry'].notna()
fb_positive_duration = surv_fb['duration'] > 0
s_fb = surv_fb[fb_has_z1 & fb_positive_duration]
n_fb_events = s_fb['status'].sum()
n_fb_censored = (s_fb['status'] == 0).sum()
print(f"  Sample: n={len(s_fb)}, events={n_fb_events}, censored={n_fb_censored}")

if n_fb_events < 3:
    # With fewer than three events the model is skipped; record a placeholder row
    print(f"  Too few fell-back events ({n_fb_events}) for Cox model — skipping")
    round2_results.append({
        'analysis': 'Competing risk: cause-specific Cox',
        'specification': 'Falling back (crossed censored)',
        'demo_var': 'Z_1_entry',
        'n': len(s_fb),
        'note': f'Too few events ({n_fb_events}) — model not estimable',
    })
else:
    try:
        cox_fb = PHReg(
            s_fb['duration'].values,
            s_fb[['Z_1_entry']].values,
            status=s_fb['status'].values,
        )
        cox_fb_result = cox_fb.fit()
        log_hr_fb = cox_fb_result.params[0]
        hr_fb = np.exp(log_hr_fb)
        p_fb = cox_fb_result.pvalues[0]
        # Significance markers at the 1% / 5% / 10% levels
        if p_fb < 0.01:
            sig_fb = '***'
        elif p_fb < 0.05:
            sig_fb = '**'
        elif p_fb < 0.1:
            sig_fb = '*'
        else:
            sig_fb = ''
        print(f"  Z₁ log-HR: {log_hr_fb:.4f}, HR={hr_fb:.3f}, p={p_fb:.4f}{sig_fb}")
        round2_results.append({
            'analysis': 'Competing risk: cause-specific Cox',
            'specification': 'Falling back (crossed censored)',
            'demo_var': 'Z_1_entry',
            'demo_coef': log_hr_fb,
            'demo_se': cox_fb_result.bse[0],
            'demo_pval': p_fb,
            'n': len(s_fb),
            'note': f'HR={hr_fb:.3f}',
        })
    except Exception as err:
        # Report the failure and keep going with the remaining analyses
        print(f"  Failed: {err}")
        import traceback
        traceback.print_exc()

# ─── 2c. Fine-Gray subdistribution hazard approximation ───
print("\n  2c. FINE-GRAY SUBDISTRIBUTION HAZARD (approximation)")
print("  lifelines not available; implementing subdistribution weights manually.")

# Fine-Gray approach: for the crossing event, subjects who experience the
# competing event (fell back) remain in the risk set with time-varying
# (IPCW) weights. The approximation used here keeps fell-back subjects in the
# risk set for the full follow-up period, i.e. treats them as censored at the
# longest observed duration instead of at their fallback time.
surv_fg = surv.copy()
max_duration = surv_fg['duration'].max()
fell_back_mask = (surv_fg['event'] == 2).values
# Fell-back subjects: duration = max follow-up; everyone else keeps their own.
surv_fg['fg_duration'] = np.where(fell_back_mask, max_duration, surv_fg['duration']).astype(float)
# Only a crossing (event == 1) is an event; censored and fell-back rows are 0.
# Event codes are limited to 0/1/2 where the survival records are built.
surv_fg['fg_status'] = (surv_fg['event'] == 1).astype(float)

s_fg = surv_fg[surv_fg['Z_1_entry'].notna() & (surv_fg['fg_duration'] > 0)]
print(f"  Sample: n={len(s_fg)}, events (crossed)={int(s_fg['fg_status'].sum())}")

try:
    cox_fg = PHReg(s_fg['fg_duration'].values, s_fg[['Z_1_entry']].values,
                   status=s_fg['fg_status'].values.astype(int))
    cox_fg_result = cox_fg.fit()
    hr_fg = np.exp(cox_fg_result.params[0])
    p_fg = cox_fg_result.pvalues[0]
    sig_fg = '***' if p_fg < 0.01 else '**' if p_fg < 0.05 else '*' if p_fg < 0.1 else ''
    print(f"  Z₁ subdist. log-HR: {cox_fg_result.params[0]:.4f}, HR={hr_fg:.3f}, p={p_fg:.4f}{sig_fg}")
    # hr_cross exists only if the cause-specific crossing model (2a) succeeded.
    # Its absence must not discard this model's estimate, so the comparison
    # line is guarded separately from the fit.
    try:
        print(f"  Compare: cause-specific HR={hr_cross:.3f}, Fine-Gray approx HR={hr_fg:.3f}")
    except NameError:
        print(f"  Compare: cause-specific HR=N/A, Fine-Gray approx HR={hr_fg:.3f}")
    round2_results.append({
        'analysis': 'Competing risk: Fine-Gray (approx)',
        'specification': 'Subdistribution hazard for crossing',
        'demo_var': 'Z_1_entry',
        'demo_coef': cox_fg_result.params[0],
        'demo_se': cox_fg_result.bse[0],
        'demo_pval': p_fg,
        'n': len(s_fg),
        'note': f'HR={hr_fg:.3f} (fell-back kept in risk set)',
    })
except Exception as e:
    # Report the failure and keep going with the remaining analyses
    print(f"  Failed: {e}")
    import traceback
    traceback.print_exc()

# ─── 2d. Multinomial logit: cross / stuck / fallback ───
print("\n  2d. MULTINOMIAL LOGIT: P(cross) vs P(stuck) vs P(fallback)")

# Create outcome variable: 0=stuck/censored, 1=crossed, 2=fell_back
mlogit_data = surv[['iso3', 'event', 'Z_1_entry', 'gdp_pc_ppp_entry',
                     'kaopen_entry', 'working_age_share_entry']].dropna(subset=['Z_1_entry']).copy()
mlogit_data['outcome'] = mlogit_data['event']  # 0=censored, 1=crossed, 2=fell_back

print(f"\n  Outcome distribution:")
print(f"    Stuck/censored (0): {(mlogit_data['outcome']==0).sum()}")
print(f"    Crossed (1):        {(mlogit_data['outcome']==1).sum()}")
print(f"    Fell back (2):      {(mlogit_data['outcome']==2).sum()}")

# 2d-i. Z₁ only
print("\n  2d-i. Multinomial logit: Z₁ only")

try:
    y_mlogit = mlogit_data['outcome'].values
    X_mlogit = sm.add_constant(mlogit_data[['Z_1_entry']].values)

    mlogit = sm.MNLogit(y_mlogit, X_mlogit)
    mlogit_result = mlogit.fit(disp=0)

    print(f"  Pseudo-R²: {mlogit_result.prsquared:.4f}")
    print(f"  AIC: {mlogit_result.aic:.1f}")

    # MNLogit uses the smallest observed category as the reference and reports
    # one parameter column per remaining category, in sorted order. Deriving
    # the base from the data keeps the labels right even when no observation
    # is censored (category 0 absent).
    outcome_names = {0: 'Stuck', 1: 'Crossed', 2: 'Fell back'}
    categories = sorted(mlogit_data['outcome'].unique())
    base_cat = categories[0]
    base_name = outcome_names.get(base_cat, f'Outcome {base_cat}')
    base_desc = 'stuck/censored' if base_cat == 0 else base_name.lower()
    print(f"\n  Results (base category = {base_cat} = {base_desc}):")

    param_names = ['const', 'Z_1_entry']
    # MNLogit params shape: (n_params, n_categories-1)
    n_equations = mlogit_result.params.shape[1]
    for col_idx, cat in enumerate(categories[1:]):
        if col_idx >= n_equations:
            break
        cat_name = outcome_names.get(cat, f'Outcome {cat}')
        print(f"\n  Category {cat} ({cat_name}) vs {base_name}:")
        for k, pname in enumerate(param_names):
            coef = mlogit_result.params[k, col_idx]
            se = mlogit_result.bse[k, col_idx]
            pval = mlogit_result.pvalues[k, col_idx]
            sig = '***' if pval < 0.01 else '**' if pval < 0.05 else '*' if pval < 0.1 else ''
            print(f"    {pname:15s}  coef={coef:8.4f}  se={se:.4f}  p={pval:.4f}{sig}")
            if pname == 'Z_1_entry':
                round2_results.append({
                    'analysis': 'Multinomial logit',
                    'specification': f'P({cat_name}) vs P({base_name})',
                    'demo_var': 'Z_1_entry',
                    'demo_coef': coef,
                    'demo_se': se,
                    'demo_pval': pval,
                    'n': len(mlogit_data),
                    'note': f'Pseudo-R²={mlogit_result.prsquared:.4f}',
                })

except Exception as e:
    # Report the failure and keep going with the remaining analyses
    print(f"  Multinomial logit failed: {e}")
    import traceback
    traceback.print_exc()

# 2d-ii. Multinomial logit with controls
print("\n  2d-ii. Multinomial logit: Z₁ + GDP at entry")
# Same outcome coding as 2d-i, restricted to entrants with GDP at entry
mlogit_data2 = mlogit_data.dropna(subset=['gdp_pc_ppp_entry']).copy()
mlogit_data2['log_gdp_entry'] = np.log(mlogit_data2['gdp_pc_ppp_entry'])

try:
    y_ml2 = mlogit_data2['outcome'].values
    X_ml2 = sm.add_constant(mlogit_data2[['Z_1_entry', 'log_gdp_entry']].values)

    mlogit2 = sm.MNLogit(y_ml2, X_ml2)
    mlogit2_result = mlogit2.fit(disp=0)

    print(f"  Pseudo-R²: {mlogit2_result.prsquared:.4f}")

    # MNLogit uses the smallest observed category as the reference and reports
    # one parameter column per remaining category, in sorted order. Deriving
    # the base from the data keeps the labels right even when no observation
    # is censored (category 0 absent).
    outcome_names2 = {0: 'Stuck', 1: 'Crossed', 2: 'Fell back'}
    categories2 = sorted(mlogit_data2['outcome'].unique())
    base_cat2 = categories2[0]
    base_name2 = outcome_names2.get(base_cat2, f'Outcome {base_cat2}')

    param_names2 = ['const', 'Z_1_entry', 'log_gdp_entry']
    n_equations2 = mlogit2_result.params.shape[1]
    for col_idx, cat in enumerate(categories2[1:]):
        if col_idx >= n_equations2:
            break
        cat_name = outcome_names2.get(cat, f'Outcome {cat}')
        print(f"\n  Category {cat} ({cat_name}) vs {base_name2}:")
        for k, pname in enumerate(param_names2):
            coef = mlogit2_result.params[k, col_idx]
            se = mlogit2_result.bse[k, col_idx]
            pval = mlogit2_result.pvalues[k, col_idx]
            sig = '***' if pval < 0.01 else '**' if pval < 0.05 else '*' if pval < 0.1 else ''
            print(f"    {pname:15s}  coef={coef:8.4f}  se={se:.4f}  p={pval:.4f}{sig}")
            # Only the crossed-vs-stuck Z₁ coefficient goes into the output
            # table, and only when 'stuck' really is the reference category.
            if pname == 'Z_1_entry' and cat == 1 and base_cat2 == 0:
                round2_results.append({
                    'analysis': 'Multinomial logit',
                    'specification': 'P(Crossed) vs P(Stuck) + GDP control',
                    'demo_var': 'Z_1_entry',
                    'demo_coef': coef,
                    'demo_se': se,
                    'demo_pval': pval,
                    'n': len(mlogit_data2),
                    'note': f'Pseudo-R²={mlogit2_result.prsquared:.4f}',
                })

except Exception as e:
    # Report the failure and keep going with the remaining analyses
    print(f"  Multinomial logit (controlled) failed: {e}")
    import traceback
    traceback.print_exc()

# Summary comparison
print("\n  COMPETING RISK SUMMARY:")
print(f"  {'Model':45s} {'HR/coef':>10s} {'p-value':>10s}")
print("  " + "-" * 68)
print(f"  {'Standard Cox (from paper, univariate)':45s} {'2.400':>10s} {'<0.05':>10s}")
# Each estimate exists only if its model ran and converged earlier in the
# script. A missing name means "not estimated", so it is shown as N/A; any
# other error is a real bug and is allowed to surface.
try:
    print(f"  {'Cause-specific Cox (crossing, fell-back censored)':45s} {hr_cross:10.3f} {p_cross:10.4f}")
except NameError:
    print(f"  {'Cause-specific Cox (crossing)':45s} {'N/A':>10s} {'N/A':>10s}")
try:
    print(f"  {'Cause-specific Cox (falling back, crossed censored)':45s} {hr_fb:10.3f} {p_fb:10.4f}")
except NameError:
    print(f"  {'Cause-specific Cox (falling back)':45s} {'N/A':>10s} {'N/A':>10s}")
try:
    print(f"  {'Fine-Gray approx (subdistribution hazard)':45s} {hr_fg:10.3f} {p_fg:10.4f}")
except NameError:
    print(f"  {'Fine-Gray approx (subdistribution hazard)':45s} {'N/A':>10s} {'N/A':>10s}")


# ═══════════════════════════════════════════════════════════════════════
# 3. ALTERNATIVE THRESHOLD CROSSING COUNTS
# ═══════════════════════════════════════════════════════════════════════
section_rule = "=" * 70
print("\n" + section_rule)
print("3. ALTERNATIVE THRESHOLD CROSSING COUNTS RECONCILIATION")
print(section_rule)
print("Referee notes inconsistent n_crossed counts in robustness table.")
print("Clarifying: 'entered from below AND crossed above' vs 'ever crossed upper bound'.\n")

# Alternative (lower, upper) zone bounds, written in thousands for readability
threshold_defs = [
    (lower_k * 1000, upper_k * 1000)
    for lower_k, upper_k in ((7, 20), (8, 22), (9, 25), (10, 30), (12, 30))
]

def _trajectory_flags(gdp, lower, upper):
    """Classify one country's GDP path against the zone [lower, upper].

    ``gdp`` holds the country's non-missing ``year`` / ``gdp_pc_ppp`` rows,
    sorted by year. Returns a dict of booleans:

    - ``ever_in``: some observation lies in [lower, upper], or the path starts
      below ``lower`` and later exceeds ``upper`` (passed through the zone
      between observations).
    - ``entered_from_below``: some observation at or above ``lower`` occurs in
      a later year than the first observation below ``lower``.
    - ``crossed_above``: some observation exceeds ``upper``.
    - ``below_then_crossed``: some observation above ``upper`` occurs in a
      later year than the first observation below ``lower``.
    """
    level = gdp['gdp_pc_ppp']
    year = gdp['year']
    below = level < lower
    at_or_above = level >= lower
    above = level > upper

    ever_in = bool((at_or_above & (level <= upper)).any())
    if level.iloc[0] < lower and level.max() > upper:
        ever_in = True

    entered_from_below = False
    below_then_crossed = False
    if below.any():
        # Require temporal order: being below must come BEFORE the later
        # observation. Otherwise a country that starts inside the zone and
        # only ever falls out of it would count as having entered from below.
        after_first_below = year > year[below].min()
        entered_from_below = bool((at_or_above & after_first_below).any())
        below_then_crossed = bool((above & after_first_below).any())

    return {
        'ever_in': ever_in,
        'entered_from_below': entered_from_below,
        'crossed_above': bool(above.any()),
        'below_then_crossed': below_then_crossed,
    }


# Per-country GDP paths do not depend on the thresholds, so build them once
# rather than re-filtering the panel for every country under every threshold.
# Countries with fewer than 5 GDP observations are skipped.
country_paths = []
for iso, c in panel.groupby('iso3', sort=False):
    gdp = c.loc[c['gdp_pc_ppp'].notna(), ['year', 'gdp_pc_ppp']].sort_values('year')
    if len(gdp) >= 5:
        country_paths.append((iso, gdp))

count_rows = []
for lower, upper in threshold_defs:
    label = f"${lower//1000}k-${upper//1000}k"

    # Counters for this threshold pair
    ever_entered = 0
    entered_from_below = 0
    crossed_above = 0
    entered_below_then_crossed = 0

    country_details = {'entered_below_crossed': [], 'crossed_from_zone': []}

    for iso, gdp in country_paths:
        flags = _trajectory_flags(gdp, lower, upper)

        # a) Ever in (or passed through) the zone
        if flags['ever_in']:
            ever_entered += 1

        # b) Entered from below: was below lower, then later at/above lower
        if flags['entered_from_below']:
            entered_from_below += 1

        # d) Full transit: below lower first, above upper later
        if flags['below_then_crossed']:
            entered_below_then_crossed += 1
            country_details['entered_below_crossed'].append(iso)

        # c) Crossed above upper (from any starting point)
        if flags['crossed_above']:
            crossed_above += 1
            if not flags['below_then_crossed']:
                country_details['crossed_from_zone'].append(iso)

    print(f"\n  Threshold: {label}")
    print(f"    a) Ever entered zone:                          {ever_entered:4d}")
    print(f"    b) Entered from below (was <${lower//1000}k):           {entered_from_below:4d}")
    print(f"    c) Ever crossed above ${upper//1000}k:                  {crossed_above:4d}")
    print(f"    d) Entered from below AND crossed above:       {entered_below_then_crossed:4d}")
    if lower == 9000 and upper == 25000:
        print(f"\n    Countries that entered from below AND crossed (d):")
        for iso in sorted(country_details['entered_below_crossed']):
            print(f"      {iso}")

    count_rows.append({
        'threshold': label,
        'lower': lower,
        'upper': upper,
        'ever_entered_zone': ever_entered,
        'entered_from_below': entered_from_below,
        'ever_crossed_above': crossed_above,
        'below_to_above': entered_below_then_crossed,
    })

counts_df = pd.DataFrame(count_rows)
print("\n  CROSSING COUNT RECONCILIATION TABLE:")
print(counts_df.to_string(index=False))

# Explain why counts (c) and (d) differ
explanation_lines = [
    "\n  EXPLANATION OF DISCREPANCY:",
    "  'n_crossed' in the paper = countries that ever crossed above the upper bound",
    "  (count c), which includes countries that started inside the zone.",
    "  The smaller count (d) = full transit from below lower to above upper.",
    "  Both are valid but measure different things:",
    "  (c) captures all successful crossings regardless of origin",
    "  (d) captures the 'development miracle' — full transit through the zone",
]
print("\n".join(explanation_lines))

# Add one output-table row per threshold definition
round2_results.extend(
    {
        'analysis': 'Threshold counts',
        'specification': counts['threshold'],
        'n': counts['ever_entered_zone'],
        'note': f"entered_below={counts['entered_from_below']}, "
                f"crossed_above={counts['ever_crossed_above']}, "
                f"below_to_above={counts['below_to_above']}",
    }
    for counts in count_rows
)


# ═══════════════════════════════════════════════════════════════════════
# SAVE OUTPUT TABLE
# ═══════════════════════════════════════════════════════════════════════
output_rule = "=" * 70
print("\n" + output_rule, "OUTPUT TABLE", output_rule, sep="\n")

# Write every collected result row to the round-2 table and echo it
outpath = '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table19_round2.csv'
results_df = pd.DataFrame(round2_results)
results_df.to_csv(outpath, index=False)
print(f"\nSaved to: {outpath}")
print(results_df.to_string(index=False))

print("\n" + output_rule, "PHASE 9 (ROUND 2 REVISIONS) COMPLETE", output_rule, sep="\n")
